from __future__ import absolute_import

from .myresnet_v2 import PersonResNet
from .rga_model import resnet50_rga

# Registry mapping public model names to the callables that build them.
# Keys are the names accepted by ``create_model``.
__factory = dict(
    PersonResNet=PersonResNet,
    resnet50_rga=resnet50_rga,
)


def model_names():
    """Return a new list of every registered model name, sorted ascending."""
    names = list(__factory)
    names.sort()
    return names


def create_model(name, *args, **kwargs):
    """Build a registered model by name.

    Parameters
    ----------
    name : str
        A key of the module-level model registry (see ``model_names()``).
    *args, **kwargs
        Forwarded unchanged to the registered callable.

    Returns
    -------
    Whatever the registered callable returns for the given arguments.

    Raises
    ------
    KeyError
        If ``name`` is not a registered model name.
    """
    if name not in __factory:
        # Pass a single message string: KeyError("a", b) stores a 2-tuple in
        # ``args`` and renders as "('Unknown Model:', 'x')".
        raise KeyError(
            "Unknown Model: {!r}. Available models: {}".format(
                name, ", ".join(sorted(__factory))
            )
        )
    return __factory[name](*args, **kwargs)
